Conversation
This PR adds:
What do you think @SarahAlidoost and @SCiarella? I think this can simplify some expressions in the models - I imagine a lot of the
Thanks @fnattino, this looks fantastic 🚀 I really like the template to automatically broadcast to the correct shape and device at the beginning, because right now we are doing it quite a lot of times in the integration loops. Ideally, it would be nice to remove all the calls to
One thing is the naming "
It is awesome! 🥇 Thanks. I like how things get simpler and cleaner. Just one comment about naming, see above.
Thank you @SCiarella and @SarahAlidoost for the useful feedback!
Indeed, I think it's a good idea to also add similar containers for states and rates, so all variables are initialized with the correct shape and device!
My idea was to use something like the following:

```python
import torch

from diffwofost.physical_models.base import TensorParamTemplate
from diffwofost.physical_models.traitlets import Tensor


class Parameters(TensorParamTemplate):
    A = Tensor(0.)
    B = Tensor(0, dtype=int)


# Parameters A and B are cast into tensors
params = Parameters(dict(A=0., B=0))

params.A
# tensor(0., dtype=torch.float64)
params.B
# tensor(0)
```
```python
tmin = _get_drv(drv.TMIN, self.params.shape, dtype=self.dtype, device=self.device)

# Assimilation is zero before crop emergence (DVS < 0)
dvs_mask = (dvs >= 0).to(dtype=self.dtype)
```
Was this needed? Why can't we leave it as a tensor with dtype bool?
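For illustration only, a minimal sketch (with made-up values) of the implicit type promotion a bool mask would rely on in arithmetic:

```python
import torch

dvs = torch.tensor([-0.1, 0.0, 0.5], dtype=torch.float64)
rate = torch.ones(3, dtype=torch.float64)

# A bool mask takes part in arithmetic via type promotion, so an explicit
# cast to the float dtype is not strictly required before multiplying:
masked_rate = rate * (dvs >= 0)
print(masked_rate)        # tensor([0., 1., 1.], dtype=torch.float64)
print(masked_rate.dtype)  # torch.float64
```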
```python
SLA = torch.zeros((self.MAX_DAYS, *params.shape), dtype=self.dtype, device=self.device)
LVAGE = torch.zeros((self.MAX_DAYS, *params.shape), dtype=self.dtype, device=self.device)
LV = torch.zeros((self.MAX_DAYS, *params.shape), dtype=self.dtype, device=self.device)
```
I realized the choice of having the time dimension as the last axis was probably not the best one. Having the time dimension as the first axis instead allows for automatic broadcasting, i.e.:
```python
import torch

# This works fine
x = torch.zeros((10, 5, 5)) + torch.ones((5, 5))

# This does not
x = torch.zeros((5, 5, 10)) + torch.ones((5, 5))
# RuntimeError: The size of tensor a (10) must match the size of tensor b (5) at non-singleton dimension 2
```

So I moved the time axis, which only implies minor changes, and only in this module.
```python
    (*self.params_shape, self.MAX_DAYS), dtype=self.dtype, device=self.device
)
LV = torch.zeros((*self.params_shape, self.MAX_DAYS), dtype=self.dtype, device=self.device)
SLA[..., 0] = params.SLATB(DVS).to(dtype=self.dtype, device=self.device)
```
The output of the afgen should already have the correct dtype and device, correct? So there is no need for the `.to(...)` here?
```python
# If DVS < 0, the crop has not yet emerged, so we zerofy the rates using mask
# A mask (0 if DVS < 0, 1 if DVS >= 0)
DVS = torch.as_tensor(k["DVS"], dtype=self.dtype, device=self.device)
dvs_mask = (DVS >= 0).to(dtype=self.dtype).to(device=self.device)
```
Same as above, I think we can leave this as a tensor with bool dtype? So we don't have to convert it to bool at lines 434, 441, and 451 below?
```python
# in DALV.
# Note that the actual leaf death is imposed on the array LV during the
# state integration step.
tSPAN = _broadcast_to(
```
This is no longer needed after I moved the time axis to the leading dimension (automatic broadcasting).
```python
    # is used.
    span_mask = hard_mask.detach() + soft_mask - soft_mask.detach()
else:
    span_mask = (s.LVAGE > tSPAN).to(dtype=self.dtype)
```
Conversion to dtype should not be needed here?
```python
self.params_shape = _get_params_shape(self.params)
```

```python
DVS = torch.as_tensor(self.kiosk["DVS"], dtype=self.dtype, device=self.device)
DVS = _broadcast_to(self.kiosk["DVS"], self.params.shape)
```
As discussed, here I make sure the input (i.e. DVS) is broadcast to the correct shape, instead of all the outputs. It's only needed in test environments, where DVS is "injected" as an external parameter.
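A minimal sketch (values and shapes are made up) of broadcasting the injected scalar input once to the parameters' shape, rather than expanding every output downstream:

```python
import torch

params_shape = (5,)  # e.g. an ensemble of parameter sets

# The scalar DVS coming from outside is broadcast once to the target shape;
# every expression that uses it then already produces outputs of that shape.
DVS = torch.broadcast_to(torch.as_tensor(0.3, dtype=torch.float64), params_shape)
print(DVS.shape)  # torch.Size([5])
```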
```python
def _on_CROP_START(
    self, day, crop_name=None, variety_name=None, crop_start_type=None, crop_end_type=None
):
    """Starts the crop."""
    self.logger.debug(f"Received signal 'CROP_START' on day {day}")

    if self.crop is not None:
        raise RuntimeError(
            "A CROP_START signal was received while self.cropsimulation still holds a valid "
            "cropsimulation object. It looks like you forgot to send a CROP_FINISH signal with "
            "option crop_delete=True"
        )

    self.parameterprovider.set_active_crop(
        crop_name, variety_name, crop_start_type, crop_end_type
    )
    self.crop = self.mconf.CROP(day, self.kiosk, self.parameterprovider, shape=self._shape)
```
I need to redefine this function to pass the shape to the crop model.
```python
self.mconf = config

self.parameterprovider = parameterprovider
self._shape = _get_params_shape(self.parameterprovider)
```
Right now the shape is inferred from the parameters only, but we might also infer it from the weather data later on.
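A purely hypothetical sketch of how a parameter-derived shape and a weather-derived shape could be reconciled later on; the variable names and shapes are made up:

```python
import torch

params_shape = (5,)   # e.g. inferred from the parameter provider
weather_shape = (1,)  # e.g. inferred from the weather data provider

# torch.broadcast_shapes returns the common shape the two sources agree on
shape = torch.broadcast_shapes(params_shape, weather_shape)
print(shape)  # torch.Size([5])
```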
```python
if x.dim() == 0:
    # For 0-d tensors, we simply broadcast to the given shape
    return torch.broadcast_to(x, shape)
# The given shape should match x in all but the last axis, which represents
# the dimension along which the time integration is carried out.
# We first append an axis to x, then expand to the given shape
return x.unsqueeze(-1).expand(shape)
```
The last part of this function was only needed when we had to broadcast the tensors in leaf dynamics with the additional time axis. This is not needed anymore, and the function is much "cleaner" now.
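A small sketch (sizes are made up) of what the removed branch used to do, i.e. appending a trailing time axis and expanding along it:

```python
import torch

params_shape, MAX_DAYS = (5,), 10

x = torch.ones(params_shape)
# Append a trailing axis, then expand it along the time dimension
y = x.unsqueeze(-1).expand((*params_shape, MAX_DAYS))
print(y.shape)  # torch.Size([5, 10])
```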
```python
from diffwofost.physical_models.utils import prepare_engine_input

from . import phy_data_folder

config = Configuration(CROP=DVS_Phenology)
```
Now the Engine implemented in diffwofost passes the shape to the crop model - so the original phenology model from PCSE cannot be used here anymore.
Hi @SarahAlidoost @SCiarella, this is now ready to be reviewed. I know it's quite a lot of changes, so I tried to leave as many comments as possible to facilitate the review.



relates #25